{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import scipy.stats  # print_summary uses scipy.stats.ttest_1samp; import the submodule explicitly\n",
    "\n",
    "try:\n",
    "    import SparseSC as SC\n",
    "except ImportError:\n",
    "    raise RuntimeError(\"SparseSC is not installed. Use 'pip install -e .' or 'conda develop .' from repo root to install in dev mode\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(12345)\n",
    "np.random.seed(101101001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "pkl_file = \"../replication/smoking_fits.pkl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "                 cigsale  lnincome  beer  age15to24   retprice\n",
       "state   year                                                  \n",
       "Alabama 1970   89.800003       NaN   NaN   0.178862  39.599998\n",
       "        1971   95.400002       NaN   NaN   0.179928  42.700001\n",
       "        1972  101.099998  9.498476   NaN   0.180994  42.299999\n",
       "        1973  102.900002  9.550107   NaN   0.182060  42.099998\n",
       "        1974  108.199997  9.537163   NaN   0.183126  43.099998"
      ],
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th></th>\n      <th>cigsale</th>\n      <th>lnincome</th>\n      <th>beer</th>\n      <th>age15to24</th>\n      <th>retprice</th>\n    </tr>\n    <tr>\n      <th>state</th>\n      <th>year</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th rowspan=\"5\" valign=\"top\">Alabama</th>\n      <th>1970</th>\n      <td>89.800003</td>\n      <td>NaN</td>\n      <td>NaN</td>\n      <td>0.178862</td>\n      <td>39.599998</td>\n    </tr>\n    <tr>\n      <th>1971</th>\n      <td>95.400002</td>\n      <td>NaN</td>\n      <td>NaN</td>\n      <td>0.179928</td>\n      <td>42.700001</td>\n    </tr>\n    <tr>\n      <th>1972</th>\n      <td>101.099998</td>\n      <td>9.498476</td>\n      <td>NaN</td>\n      <td>0.180994</td>\n      <td>42.299999</td>\n    </tr>\n    <tr>\n      <th>1973</th>\n      <td>102.900002</td>\n      <td>9.550107</td>\n      <td>NaN</td>\n      <td>0.182060</td>\n      <td>42.099998</td>\n    </tr>\n    <tr>\n      <th>1974</th>\n      <td>108.199997</td>\n      <td>9.537163</td>\n      <td>NaN</td>\n      <td>0.183126</td>\n      <td>43.099998</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "smoking_df = pd.read_stata(\"../replication/smoking.dta\")\n",
    "smoking_df['year'] = smoking_df['year'].astype('int')\n",
    "smoking_df = smoking_df.set_index(['state', 'year']).sort_index()\n",
    "smoking_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "                cigsale                                                  \\\n",
       "year               1970        1971        1972        1973        1974   \n",
       "state                                                                     \n",
       "Alabama       89.800003   95.400002  101.099998  102.900002  108.199997   \n",
       "Arkansas     100.300003  104.099998  103.900002  108.000000  109.699997   \n",
       "California   123.000000  121.000000  123.500000  124.400002  126.699997   \n",
       "Colorado     124.800003  125.500000  134.300003  137.899994  132.800003   \n",
       "Connecticut  120.000000  117.599998  110.800003  109.300003  112.400002   \n",
       "\n",
       "                                                                         ...  \\\n",
       "year               1975        1976        1977        1978        1979  ...   \n",
       "state                                                                    ...   \n",
       "Alabama      111.699997  116.199997  117.099998  123.000000  121.400002  ...   \n",
       "Arkansas     114.800003  119.099998  122.599998  127.300003  126.500000  ...   \n",
       "California   127.099998  128.000000  126.400002  126.099998  121.900002  ...   \n",
       "Colorado     131.000000  134.199997  132.000000  129.199997  131.500000  ...   \n",
       "Connecticut  110.199997  113.400002  117.300003  117.500000  117.400002  ...   \n",
       "\n",
       "                                                                         \\\n",
       "year               1991        1992        1993        1994        1995   \n",
       "state                                                                     \n",
       "Alabama      107.900002  109.099998  108.500000  107.099998  102.599998   \n",
       "Arkansas     116.800003  126.000000  113.800003  108.800003  113.000000   \n",
       "California    68.699997   67.500000   63.400002   58.599998   56.400002   \n",
       "Colorado      90.199997   88.300003   88.599998   89.099998   85.400002   \n",
       "Connecticut   86.699997   83.500000   79.099998   76.599998   79.300003   \n",
       "\n",
       "                                                                        \n",
       "year               1996        1997        1998        1999       2000  \n",
       "state                                                                   \n",
       "Alabama      101.400002  104.900002  106.199997  100.699997  96.199997  \n",
       "Arkansas     110.699997  108.699997  109.500000  104.800003  99.400002  \n",
       "California    54.500000   53.799999   52.299999   47.200001  41.599998  \n",
       "Colorado      83.099998   81.300003   81.199997   79.599998  73.000000  \n",
       "Connecticut   76.000000   75.900002   75.500000   73.400002  71.400002  \n",
       "\n",
       "[5 rows x 31 columns]"
      ],
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead tr th {\n        text-align: left;\n    }\n\n    .dataframe thead tr:last-of-type th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr>\n      <th></th>\n      <th colspan=\"21\" halign=\"left\">cigsale</th>\n    </tr>\n    <tr>\n      <th>year</th>\n      <th>1970</th>\n      <th>1971</th>\n      <th>1972</th>\n      <th>1973</th>\n      <th>1974</th>\n      <th>1975</th>\n      <th>1976</th>\n      <th>1977</th>\n      <th>1978</th>\n      <th>1979</th>\n      <th>...</th>\n      <th>1991</th>\n      <th>1992</th>\n      <th>1993</th>\n      <th>1994</th>\n      <th>1995</th>\n      <th>1996</th>\n      <th>1997</th>\n      <th>1998</th>\n      <th>1999</th>\n      <th>2000</th>\n    </tr>\n    <tr>\n      <th>state</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>Alabama</th>\n      <td>89.800003</td>\n      <td>95.400002</td>\n      <td>101.099998</td>\n      <td>102.900002</td>\n      <td>108.199997</td>\n      <td>111.699997</td>\n      <td>116.199997</td>\n      <td>117.099998</td>\n      <td>123.000000</td>\n      <td>121.400002</td>\n      <td>...</td>\n      <td>107.900002</td>\n      <td>109.099998</td>\n      <td>108.500000</td>\n      <td>107.099998</td>\n      <td>102.599998</td>\n      <td>101.400002</td>\n      <td>104.900002</td>\n      <td>106.199997</td>\n      <td>100.699997</td>\n      <td>96.199997</td>\n    
</tr>\n    <tr>\n      <th>Arkansas</th>\n      <td>100.300003</td>\n      <td>104.099998</td>\n      <td>103.900002</td>\n      <td>108.000000</td>\n      <td>109.699997</td>\n      <td>114.800003</td>\n      <td>119.099998</td>\n      <td>122.599998</td>\n      <td>127.300003</td>\n      <td>126.500000</td>\n      <td>...</td>\n      <td>116.800003</td>\n      <td>126.000000</td>\n      <td>113.800003</td>\n      <td>108.800003</td>\n      <td>113.000000</td>\n      <td>110.699997</td>\n      <td>108.699997</td>\n      <td>109.500000</td>\n      <td>104.800003</td>\n      <td>99.400002</td>\n    </tr>\n    <tr>\n      <th>California</th>\n      <td>123.000000</td>\n      <td>121.000000</td>\n      <td>123.500000</td>\n      <td>124.400002</td>\n      <td>126.699997</td>\n      <td>127.099998</td>\n      <td>128.000000</td>\n      <td>126.400002</td>\n      <td>126.099998</td>\n      <td>121.900002</td>\n      <td>...</td>\n      <td>68.699997</td>\n      <td>67.500000</td>\n      <td>63.400002</td>\n      <td>58.599998</td>\n      <td>56.400002</td>\n      <td>54.500000</td>\n      <td>53.799999</td>\n      <td>52.299999</td>\n      <td>47.200001</td>\n      <td>41.599998</td>\n    </tr>\n    <tr>\n      <th>Colorado</th>\n      <td>124.800003</td>\n      <td>125.500000</td>\n      <td>134.300003</td>\n      <td>137.899994</td>\n      <td>132.800003</td>\n      <td>131.000000</td>\n      <td>134.199997</td>\n      <td>132.000000</td>\n      <td>129.199997</td>\n      <td>131.500000</td>\n      <td>...</td>\n      <td>90.199997</td>\n      <td>88.300003</td>\n      <td>88.599998</td>\n      <td>89.099998</td>\n      <td>85.400002</td>\n      <td>83.099998</td>\n      <td>81.300003</td>\n      <td>81.199997</td>\n      <td>79.599998</td>\n      <td>73.000000</td>\n    </tr>\n    <tr>\n      <th>Connecticut</th>\n      <td>120.000000</td>\n      <td>117.599998</td>\n      <td>110.800003</td>\n      <td>109.300003</td>\n      <td>112.400002</td>\n      
<td>110.199997</td>\n      <td>113.400002</td>\n      <td>117.300003</td>\n      <td>117.500000</td>\n      <td>117.400002</td>\n      <td>...</td>\n      <td>86.699997</td>\n      <td>83.500000</td>\n      <td>79.099998</td>\n      <td>76.599998</td>\n      <td>79.300003</td>\n      <td>76.000000</td>\n      <td>75.900002</td>\n      <td>75.500000</td>\n      <td>73.400002</td>\n      <td>71.400002</td>\n    </tr>\n  </tbody>\n</table>\n<p>5 rows × 31 columns</p>\n</div>"
     },
     "metadata": {},
     "execution_count": 5
    }
   ],
   "source": [
    "Y = smoking_df[['cigsale']].unstack('year')\n",
    "Y_cols = Y.columns\n",
    "Y.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "(39, 31)\n[0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38]\n"
     ]
    }
   ],
   "source": [
    "T0 = 19\n",
    "i_t = 2 #unit 3, but zero-index\n",
    "treated_units = [i_t]\n",
    "control_units = [u for u in range(Y.shape[0]) if u not in treated_units]\n",
    "print(Y.shape)\n",
    "print(control_units)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# Year labels come from Y_cols (saved when Y was unstacked) so this cell can be\n",
    "# re-run after Y has already been converted to an ndarray below.\n",
    "Y_names = Y_cols.get_level_values('year')\n",
    "Y_pre_names = [\"cigsale(\" + str(i) + \")\" for i in Y_names[:T0]]\n",
    "print(pd.isnull(Y).sum().sum()) #0\n",
    "# np.asarray accepts both the original DataFrame and an already-converted array\n",
    "Y = np.asarray(Y)\n",
    "T = Y.shape[1]\n",
    "T1 = T-T0\n",
    "Y_pre,Y_post = Y[:,:T0], Y[:,T0:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stata: synth cigsale beer(1984(1)1988) lnincome retprice age15to24 cigsale(1988) cigsale(1980) cigsale(1975), xperiod(1980(1)1988)  trunit(3) trperiod(1989) \n",
    "\n",
    "year_ind = smoking_df.index.get_level_values('year')\n",
    "# Row masks for the two averaging windows used in the Stata call above\n",
    "in_beer_window = (year_ind >= 1984) & (year_ind <= 1988)\n",
    "in_xperiod = (year_ind >= 1980) & (year_ind <= 1988)\n",
    "beer_pre = smoking_df.loc[in_beer_window, [\"beer\"]]\n",
    "Xother_pre = smoking_df.loc[in_xperiod, ['lnincome', 'retprice', 'age15to24']]\n",
    "# Per-state means over each window, placed side by side\n",
    "state_means = [frame.groupby('state').mean() for frame in (beer_pre, Xother_pre)]\n",
    "X_avgs = pd.concat(state_means, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#X_spot = pd.DataFrame({'cigsale_75':smoking_df.xs(1975, level='year')[\"cigsale\"], \n",
    "#                       'cigsale_80':smoking_df.xs(1980, level='year')[\"cigsale\"], \n",
    "#                       'cigsale_88':smoking_df.xs(1988, level='year')[\"cigsale\"]})\n",
    "#X_orig = pd.concat((X_avgs, X_spot), axis=1)\n",
    "#X_orig.isnull().sum().sum() #0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_full = pd.concat((X_avgs, beer_pre.unstack('year'), Xother_pre.unstack('year')), axis=1)\n",
    "# Unstacked columns are (variable, year) tuples; the averaged columns are plain strings.\n",
    "# Test the type rather than len(c)==2, which would also match a two-character string name.\n",
    "X_full_names = [c[0] + \"(\" + str(c[1]) + \")\" if isinstance(c, tuple) else c for c in X_full.columns]\n",
    "# Enforce the expected \"no missing values\" instead of evaluating and discarding the count\n",
    "n_missing = X_full.isnull().sum().sum()\n",
    "assert n_missing == 0, \"X_full has {} missing values\".format(n_missing)\n",
    "X_full = X_full.values\n",
    "X_Y_pre = np.concatenate((X_full, Y_pre), axis=1)\n",
    "X_Y_pre_names = X_full_names + Y_pre_names\n",
    "X_Y_pre_names_arr = np.array(X_Y_pre_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_summary(full_fit, Y_pre, Y_post, Y_sc, show_noNH = True):\n",
    "    \"\"\"Print the V weights of a fit and error diagnostics on the control units.\n",
    "\n",
    "    Reads the notebook-level names `T0`, `control_units` and `X_Y_pre_names_arr`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    full_fit : fitted model object; only its `V` attribute is read.\n",
    "    Y_pre, Y_post : observed outcomes, rows indexed by unit.\n",
    "    Y_sc : predicted outcomes for all periods; columns `[:T0]` are compared with\n",
    "        `Y_pre` and columns `[T0:]` with `Y_post`.\n",
    "    show_noNH : if True, repeat the diagnostics after dropping control row 20.\n",
    "    \"\"\"\n",
    "    Y_pre_sc, Y_post_sc = Y_sc[:,:T0], Y_sc[:,T0:]\n",
    "\n",
    "    # Extract the diagonal of V once and reuse it for every report line\n",
    "    v_diag = np.diag(full_fit.V)\n",
    "    v_pos = v_diag > 0\n",
    "    print(\"V: \" + str(v_diag))\n",
    "    print(\"V>0: \" + str(v_diag[v_pos]))\n",
    "    print(\"#V>0: \" + str(sum(v_pos)))\n",
    "\n",
    "    # Observed minus predicted, restricted to the control rows\n",
    "    pre_effect_c = Y_pre[control_units, :] - Y_pre_sc[control_units, :]\n",
    "    post_effect_c = Y_post[control_units, :] - Y_post_sc[control_units, :]\n",
    "\n",
    "    # Names of the matching variables that received positive weight\n",
    "    print(X_Y_pre_names_arr[v_pos])\n",
    "\n",
    "    def print_seg_info(arr, seg_name):\n",
    "        \"\"\"Print mean, one-sample t-test against 0, MSE and mean column-wise max |value| of `arr`.\"\"\"\n",
    "        print(\"Avg bias \" + seg_name + \": \" + str(arr.mean()))\n",
    "        print(scipy.stats.ttest_1samp(arr.flatten(), popmean=0)) \n",
    "        print(\"Avg MSE \" + seg_name + \": \" + str(np.mean(np.power(arr, 2))) )\n",
    "        print(\"Avg max abs val \" + seg_name + \":\" + str(np.mean(np.amax(np.abs(arr), axis=0))))\n",
    "\n",
    "    print_seg_info(pre_effect_c, \"pre\")\n",
    "    print_seg_info(post_effect_c, \"post\")\n",
    "\n",
    "    if show_noNH:\n",
    "        NH_idx = 20 #1-based index including treatment is 22\n",
    "        # Same diagnostics with that single control row removed\n",
    "        pre_effect_c_noNH = np.delete(pre_effect_c, NH_idx, axis=0)\n",
    "        post_effect_c_noNH = np.delete(post_effect_c, NH_idx, axis=0)\n",
    "\n",
    "        print_seg_info(pre_effect_c_noNH, \"pre (no-NH)\")\n",
    "        print_seg_info(post_effect_c_noNH, \"post (no-NH)\")\n"
   ]
  },
  {
   "source": [
    "Fast"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    fast_fit = SC.fit_fast(X_Y_pre, Y_post, treated_units=[i_t])\n",
    "    #print(len(np.diag(fast_fit.V)))\n",
    "    #print(np.diag(fast_fit.V))\n",
    "    #Y_post_sc = fast_fit.predict(Y_post)\n",
    "    #Y_pre_sc = fast_fit.predict(Y_pre)\n",
    "    #post_mse = np.mean(np.power(Y_post[control_units, :] - Y_post_sc[control_units, :], 2))\n",
    "    #pre_mse = np.mean(np.power(Y_pre[control_units, :] - Y_pre_sc[control_units, :], 2))\n",
    "    #print(pre_mse) #192.210632448\n",
    "    #print(post_mse) #129.190437803\n",
    "    #print(X_Y_pre_names_arr[fast_fit.match_space_desc>0])"
   ]
  },
  {
   "source": [
    "Full"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Start time: 2021-01-01 14:17:04\n",
      "End time: 2021-01-01 14:34:29\n"
     ]
    }
   ],
   "source": [
    "#use_est_fn = True\n",
    "#if use_est_fn:\n",
    "#    est_fit = SC.estimate_effects(\n",
    "#        outcomes=Y,\n",
    "#        unit_treatment_periods=X,\n",
    "#        covariates=None,\n",
    "#        fast = False,\n",
    "#        cf_folds = len(control_units), #sync with helper\n",
    "#        **kwargs\n",
    "#    )\n",
    "#    full_fit = \n",
    "#else: \n",
    "print(\"Start time: {}\".format(datetime.datetime.now().replace(microsecond=0)))\n",
    "full_fit = SC.fit(X_Y_pre, Y_post, treated_units=[i_t], verbose=0, progress=False, print_path=False)\n",
    "print(\"End time: {}\".format(datetime.datetime.now().replace(microsecond=0)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "V: [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 5.62581428e-06 1.67298271e-04\n 1.25447705e-04 2.83218864e-05 1.20684653e-04 2.90538206e-05\n 1.68471635e-04 1.91157505e-04 1.95644283e-04 2.87258604e-04\n 4.25739528e-04 4.62630652e-04 6.40482407e-04]\nV>0: [5.62581428e-06 1.67298271e-04 1.25447705e-04 2.83218864e-05\n 1.20684653e-04 2.90538206e-05 1.68471635e-04 1.91157505e-04\n 1.95644283e-04 2.87258604e-04 4.25739528e-04 4.62630652e-04\n 6.40482407e-04]\n#V>0: 13\n['cigsale(1976)' 'cigsale(1977)' 'cigsale(1978)' 'cigsale(1979)'\n 'cigsale(1980)' 'cigsale(1981)' 'cigsale(1982)' 'cigsale(1983)'\n 'cigsale(1984)' 'cigsale(1985)' 'cigsale(1986)' 'cigsale(1987)'\n 'cigsale(1988)']\nAvg bias pre: 0.07756219711357981\nTtest_1sampResult(statistic=0.15231397793811494, pvalue=0.8789819126134432)\nAvg MSE pre: 186.969136939326\nAvg max abs val pre:43.94306313035026\nAvg bias post: 0.036591103495529125\nTtest_1sampResult(statistic=0.0654025425110474, pvalue=0.9478822253243336)\nAvg MSE post: 142.42204905739632\nAvg max abs val post:28.952062858014358\nAvg bias pre (no-NH): -1.0717637082830929\nTtest_1sampResult(statistic=-2.592180790152355, pvalue=0.00973536548850398)\nAvg MSE pre (no-NH): 121.15513896073007\nAvg max abs val pre (no-NH):31.277865279231538\nAvg bias post (no-NH): 0.1280987232307059\nTtest_1sampResult(statistic=0.22671379737775588, 
pvalue=0.8207508739680021)\nAvg MSE post (no-NH): 141.44506906760296\nAvg max abs val post (no-NH):28.952062858014358\n"
     ]
    }
   ],
   "source": [
    "full_Y_sc = full_fit.predict(Y)\n",
    "print_summary(full_fit, Y_pre, Y_post, full_Y_sc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      " |>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>| \n"
     ]
    }
   ],
   "source": [
    "# One fold per control unit; derive the count from control_units instead of hard-coding 38\n",
    "honest_predictions, cf_fits = SC.get_c_predictions_honest(X_Y_pre[control_units,:], Y_post[control_units,:], Y[control_units,:], \n",
    "                                                 w_pen=full_fit.initial_w_pen, v_pen=full_fit.initial_v_pen, cf_folds=len(control_units), verbose=1, progress=False, fast=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "V: [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n 0.00000000e+00 0.00000000e+00 5.62581428e-06 1.67298271e-04\n 1.25447705e-04 2.83218864e-05 1.20684653e-04 2.90538206e-05\n 1.68471635e-04 1.91157505e-04 1.95644283e-04 2.87258604e-04\n 4.25739528e-04 4.62630652e-04 6.40482407e-04]\nV>0: [5.62581428e-06 1.67298271e-04 1.25447705e-04 2.83218864e-05\n 1.20684653e-04 2.90538206e-05 1.68471635e-04 1.91157505e-04\n 1.95644283e-04 2.87258604e-04 4.25739528e-04 4.62630652e-04\n 6.40482407e-04]\n#V>0: 13\n['cigsale(1976)' 'cigsale(1977)' 'cigsale(1978)' 'cigsale(1979)'\n 'cigsale(1980)' 'cigsale(1981)' 'cigsale(1982)' 'cigsale(1983)'\n 'cigsale(1984)' 'cigsale(1985)' 'cigsale(1986)' 'cigsale(1987)'\n 'cigsale(1988)']\nAvg bias pre: -0.7350711981015192\nTtest_1sampResult(statistic=-1.3801169113351313, pvalue=0.16797842226970588)\nAvg MSE pre: 205.07282757064715\nAvg max abs val pre:44.605702450400905\nAvg bias post: -0.21665792297898678\nTtest_1sampResult(statistic=-0.3608685113352452, pvalue=0.7183652274687031)\nAvg MSE post: 164.05401118717884\nAvg max abs val post:30.973639170328777\nAvg bias pre (no-NH): -1.9151725931832326\nTtest_1sampResult(statistic=-4.336260426835046, pvalue=1.661526771635277e-05)\nAvg MSE pre (no-NH): 140.60533540082383\nAvg max abs val pre (no-NH):31.151243109452096\nAvg bias post (no-NH): -0.16072957150570982\nTtest_1sampResult(statistic=-0.26377389204835666, 
pvalue=0.7920768075601825)\nAvg MSE post (no-NH): 164.51287385286167\nAvg max abs val post (no-NH):30.973639170328777\n"
     ]
    }
   ],
   "source": [
    "# Work on a copy so full_Y_sc keeps the in-sample predictions and this cell stays re-runnable\n",
    "honest_Y_sc = full_Y_sc.copy()\n",
    "honest_Y_sc[control_units,:] = honest_predictions\n",
    "print_summary(full_fit, Y_pre, Y_post, honest_Y_sc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(39, 31)"
      ]
     },
     "metadata": {},
     "execution_count": 21
    }
   ],
   "source": [
    "full_Y_sc.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "honest_predictions_df = pd.DataFrame(Y[control_units,:] - honest_predictions, columns=Y_cols, index=pd.Index(np.array(control_units)+1, name=\"state\")).stack(level=\"year\")\n",
    "honest_predictions_df.to_stata(\"smoking_sparsesc.dta\")"
   ]
  },
  {
   "source": [
    "# Full - flat\n",
    "Since we don't fit V here (it is fixed to all ones), we don't have to do out-of-sample refitting."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_fit_flat = SC._fit_fast_inner(X_Y_pre, X_Y_pre, Y_post, V=np.repeat(1,X_Y_pre.shape[1]), treated_units=[i_t])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "V: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\nV>0: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]\n#V>0: 55\nAvg bias pre: 0.003513363039932098\nTtest_1sampResult(statistic=0.05033751319674089, pvalue=0.959867371950055)\nAvg MSE pre: 3.5123625231089566\nAvg effect post: -0.40206830732880827\nTtest_1sampResult(statistic=-0.7260495769207599, pvalue=0.46818168129036175)\nAvg MSE post: 139.69517126549485\n['beer' 'lnincome' 'retprice' 'age15to24' 'beer(1984)' 'beer(1985)'\n 'beer(1986)' 'beer(1987)' 'beer(1988)' 'lnincome(1980)' 'lnincome(1981)'\n 'lnincome(1982)' 'lnincome(1983)' 'lnincome(1984)' 'lnincome(1985)'\n 'lnincome(1986)' 'lnincome(1987)' 'lnincome(1988)' 'retprice(1980)'\n 'retprice(1981)' 'retprice(1982)' 'retprice(1983)' 'retprice(1984)'\n 'retprice(1985)' 'retprice(1986)' 'retprice(1987)' 'retprice(1988)'\n 'age15to24(1980)' 'age15to24(1981)' 'age15to24(1982)' 'age15to24(1983)'\n 'age15to24(1984)' 'age15to24(1985)' 'age15to24(1986)' 'age15to24(1987)'\n 'age15to24(1988)' 'cigsale(1970)' 'cigsale(1971)' 'cigsale(1972)'\n 'cigsale(1973)' 'cigsale(1974)' 'cigsale(1975)' 'cigsale(1976)'\n 'cigsale(1977)' 'cigsale(1978)' 'cigsale(1979)' 'cigsale(1980)'\n 'cigsale(1981)' 'cigsale(1982)' 'cigsale(1983)' 'cigsale(1984)'\n 'cigsale(1985)' 'cigsale(1986)' 'cigsale(1987)' 'cigsale(1988)']\n"
     ]
    }
   ],
   "source": [
    "full_flat_Y_sc = full_fit_flat.predict(Y)\n",
    "print_summary(full_fit_flat, Y_pre, Y_post, full_flat_Y_sc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "source": [
    "write-out"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Write out\n",
    "with open(pkl_file, \"wb\" ) as output_file:\n",
    "    pickle.dump( (full_fit),  output_file) #full_fit_flat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Read back\n",
    "with open(pkl_file, \"rb\" ) as input_file:\n",
    "    (full_fit) = pickle.load(input_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.6.8 64-bit ('SparseSC_36': conda)",
   "metadata": {
    "interpreter": {
     "hash": "d5e4199e480c30e65d4fb20d3cd9d777774bdb7741bde6dbd2b401b3aff7fdac"
    }
   }
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}